// @HEADER
//*********************************************************************//
//  SiYuan: A numerical PDE solver                                     //
//  Copyright (2022) YUAN Xi                                           //
//  This Software is released under the BSD 2-Clause license detailed  //
//  in the file "LICENSE" in the top-level SiYuan directory            //
//*********************************************************************//
// @HEADER

#ifndef _ROLSOLVER_IMPL_HPP
#define _ROLSOLVER_IMPL_HPP

#include "Teuchos_TestForException.hpp"
#include "Teuchos_ScalarTraits.hpp"

#include "Thyra_TpetraThyraWrappers.hpp"

#include "Tpetra_Map.hpp"
#include "Tpetra_CrsGraph.hpp"
#include "Tpetra_CrsMatrix.hpp"
#include "Tpetra_Import.hpp"
#include "Tpetra_Export.hpp"
#include "TpetraExt_MatrixMatrix.hpp"

namespace SiYuan {

// Initializers/Accessors

/**
 * Builds a ROL-based optimizer on top of a Piro::NOXSolver.
 *
 * @param appParams  Application parameters. Options are read from the
 *                   "Analysis" -> "ROL" sublist (a copy is kept in rolParams);
 *                   the names of the optimized parameters are stored in
 *                   "Optimization Status" -> "Parameter Names".
 * @param piroModel  Must be a Piro::NOXSolver<double>; its sub-model provides
 *                   the state and parameter spaces and the nominal values.
 * @param observer   Forwarded to the ROL objective and constraint.
 * @throws Teuchos::Exceptions::InvalidParameter for an unsupported model or an
 *         unrecognized option value.
 *
 * Every ROL object whose pointer is stored in a member is heap-allocated:
 * evalModelImpl dereferences those members after this constructor returns.
 */
template<class Scalar>
ROLSolver<Scalar>::ROLSolver( Teuchos::ParameterList& appParams,
  Teuchos::RCP<Thyra::ModelEvaluatorDefaultBase<Scalar> > piroModel,
  Teuchos::RCP< Piro::ROL_ObserverBase<Scalar> > observer
  ) : model_(piroModel)
{
  // Bind by reference: there is no need to copy the whole "Analysis" list.
  Teuchos::ParameterList& analysisParams = appParams.sublist("Analysis");
  rolParams = analysisParams.sublist("ROL");

  // Map "Verbosity Level" 1..4 onto Teuchos levels; any other value is silent.
  const int verbose = rolParams.get<int>("Verbosity Level", 3);
  Teuchos::EVerbosityLevel verbosityLevel = Teuchos::VERB_NONE;
  switch(verbose) {
    case 1: verbosityLevel = Teuchos::VERB_LOW; break;
    case 2: verbosityLevel = Teuchos::VERB_MEDIUM; break;
    case 3: verbosityLevel = Teuchos::VERB_HIGH; break;
    case 4: verbosityLevel = Teuchos::VERB_EXTREME; break;
    default: break;
  }

  // Only Piro::NOXSolver is supported. Its sub-model ("model") is the
  // evaluator the ROL objective and constraint operate on.
  const auto piroNOXSolver = Teuchos::rcp_dynamic_cast<Piro::NOXSolver<double>>(piroModel);
  TEUCHOS_TEST_FOR_EXCEPTION(piroNOXSolver.is_null(), Teuchos::Exceptions::InvalidParameter,
      std::endl << "Error in Piro::PerformROLAnalysis: " <<
      "only Piro::NOXSolver is currently supported for piroModel"<<std::endl);
  const Teuchos::RCP<Piro::SteadyStateSolver<double>> piroSSSolver =
      Teuchos::rcp_dynamic_cast<Piro::SteadyStateSolver<double>>(piroNOXSolver);
  const Teuchos::RCP<Thyra::ModelEvaluatorDefaultBase<double>> model =
      Teuchos::rcp_dynamic_cast<Thyra::ModelEvaluatorDefaultBase<double>>(piroNOXSolver->getSubModel());
  TEUCHOS_TEST_FOR_EXCEPTION(model.is_null(), Teuchos::Exceptions::InvalidParameter,
      std::endl << "Error in Piro::PerformROLAnalysis: " <<
      "the sub-model of the Piro::NOXSolver must derive from Thyra::ModelEvaluatorDefaultBase"<<std::endl);

  // x_ is filled from the sub-model's nominal state below, so it has to be
  // created in the sub-model's solution space. NOTE(review): the NOX solver
  // itself is a response-only evaluator and may report a null x space.
  x_space_ = model->get_x_space();

  out = Teuchos::VerboseObjectBase::getDefaultOStream();
  const int g_index = rolParams.get<int>("Response Vector Index", 0);

  // Number of responses of the underlying model; read by get_g_space and
  // createOutArgsImpl, which append the solution as one extra response.
  num_g_ = piroSSSolver->num_g();

  // Indices of the parameter vectors to optimize, and their component names.
  num_p_ = rolParams.get<int>("Number Of Parameters", 1);
  p_indices_.resize(num_p_);
  // Heap-allocated so the parameter-list entry owns it: an RCP to a local
  // vector would dangle as soon as this constructor returns.
  const Teuchos::RCP<std::vector<std::string>> p_names = Teuchos::rcp(new std::vector<std::string>());
  for (int i = 0; i < num_p_; ++i) {
    std::ostringstream ss; ss << "Parameter Vector Index " << i;
    p_indices_[i] = rolParams.get<int>(ss.str(), i);
    const auto names_array = piroSSSolver->getModel().get_p_names(p_indices_[i]);
    if (Teuchos::nonnull(names_array)) {
      for (const auto& name : *names_array)
        p_names->push_back(name);
    }
  }
  appParams.sublist("Optimization Status").set("Parameter Names", p_names);

  // Product space/vector holding all optimized parameter vectors. result_
  // refers to the same vector, so the optimum ends up in result_.
  Teuchos::Array<Teuchos::RCP<Thyra::VectorSpaceBase<double> const>> p_spaces(num_p_);
  Teuchos::Array<Teuchos::RCP<Thyra::VectorBase<double>>> p_vecs(num_p_);
  for (int i = 0; i < num_p_; ++i) {
    p_spaces[i] = model->get_p_space(p_indices_[i]);
    p_vecs[i] = Thyra::createMember(p_spaces[i]);
  }
  p_space_ = Thyra::productVectorSpace<double>(p_spaces);
  const Teuchos::RCP<Thyra::DefaultProductVector<double>> p_prod = Thyra::defaultProductVector<double>(p_space_, p_vecs());
  result_ = p_prod;

  // Start from the sub-model's nominal parameter values and state.
  const auto nominalValues = model->getNominalValues();
  for (int i = 0; i < num_p_; ++i) {
    const Teuchos::RCP<const Thyra::VectorBase<double> > p_init = nominalValues.get_p(p_indices_[i]);
    Thyra::copy(*p_init, p_prod->getNonconstVectorBlock(i).ptr());
  }
  x_ = Thyra::createMember(x_space_);
  Thyra::copy(*nominalValues.get_x(), x_.ptr());
  const Teuchos::RCP<Thyra::VectorBase<double>> lambda_vec = Thyra::createMember(x_space_);

  // ROL wrappers and SimOpt objective/constraint: heap-allocated because the
  // members below keep pointers to them beyond this constructor.
  const auto rol_p = ROL::makePtr<ROL::ThyraVector<double>>(p_prod);
  const auto rol_x = ROL::makePtr<ROL::ThyraVector<double>>(x_);
  const ROL::Ptr<ROL::Vector<double> > rol_lambda_ptr = ROL::makePtr<ROL::ThyraVector<double>>(lambda_vec);

  const auto obj = ROL::makePtr<Piro::ThyraProductME_Objective_SimOpt<double>>(
      *model, g_index, p_indices_, appParams, verbosityLevel, observer);
  const auto constr = ROL::makePtr<Piro::ThyraProductME_Constraint_SimOpt<double>>(
      *model, g_index, p_indices_, appParams, verbosityLevel, observer);

  constr->setSolveParameters(rolParams.sublist("ROL Options"));
  if(rolParams.isParameter("Use NOX Solver") && rolParams.get<bool>("Use NOX Solver"))
    constr->setExternalSolver(piroModel);
  constr->setNumResponses(num_g_);

  obj_ptr_ = obj;
  constr_ptr_ = constr;
  rol_p_ptr_ = rol_p;
  rol_x_ptr_ = rol_x;
  reduced_obj_ptr_ = ROL::makePtr<ROL::Reduced_Objective_SimOpt<double>>(obj_ptr_,constr_ptr_,rol_x_ptr_,rol_p_ptr_,rol_lambda_ptr);

  print_ = rolParams.get<bool>("Print Output", false);

  const int seed = rolParams.get<int>("Seed For Thyra Randomize", 42);

  //! set initial guess (or use the one provided by the Model Evaluator)
  const std::string init_guess_type = rolParams.get<std::string>("Parameter Initial Guess Type", "From Model Evaluator");
  if(init_guess_type == "Uniform Vector") {
    rol_p->putScalar(rolParams.get<double>("Uniform Parameter Guess", 1.0));
  } else if(init_guess_type == "Random Vector") {
    Teuchos::Array<double> minmax(2); minmax[0] = -1; minmax[1] = 1;
    minmax = rolParams.get<Teuchos::Array<double> >("Min And Max Of Random Parameter Guess", minmax);
    // Guard the minmax[1] access below against a user-supplied short array.
    TEUCHOS_TEST_FOR_EXCEPTION(minmax.size() < 2, Teuchos::Exceptions::InvalidParameter,
        std::endl << "Error in Piro::PerformROLAnalysis: " <<
        "\"Min And Max Of Random Parameter Guess\" must contain two entries"<<std::endl);
    ::Thyra::randomize<double>( minmax[0], minmax[1], rol_p->getVector().ptr());
  } else if(init_guess_type != "From Model Evaluator") {
    TEUCHOS_TEST_FOR_EXCEPTION(true, Teuchos::Exceptions::InvalidParameter,
              std::endl << "Error in Piro::PerformROLAnalysis: " <<
              "Parameter Initial Guess Type \"" << init_guess_type << "\" is not Known.\nValid options are: \"From Model Evaluator\", \"Uniform Vector\" and \"Random Vector\""<<std::endl);
  }

  //! test thyra implementation of ROL vector
  if(rolParams.get<bool>("Test Vector", false)) {
    const Teuchos::RCP<Thyra::VectorBase<double> > rand_vec_x = result_->clone_v();
    const Teuchos::RCP<Thyra::VectorBase<double> > rand_vec_y = result_->clone_v();
    const Teuchos::RCP<Thyra::VectorBase<double> > rand_vec_z = result_->clone_v();
    ::Thyra::seed_randomize<double>( seed );

    const int num_tests = rolParams.get<int>("Number Of Vector Tests", 1);

    for(int i = 0; i < num_tests; ++i) {
      *out << "\nROL performing vector test " << i+1 << " of " << num_tests << std::endl;

      ::Thyra::randomize<double>( -1.0, 1.0, rand_vec_x.ptr());
      ::Thyra::randomize<double>( -1.0, 1.0, rand_vec_y.ptr());
      ::Thyra::randomize<double>( -1.0, 1.0, rand_vec_z.ptr());

      // Distinct names so the state wrapper rol_x above is not shadowed.
      ROL::ThyraVector<double> test_x(rand_vec_x);
      ROL::ThyraVector<double> test_y(rand_vec_y);
      ROL::ThyraVector<double> test_z(rand_vec_z);

      test_x.checkVector(test_y, test_z, print_, *out);
    }
  }

  useFullSpace = rolParams.get("Full Space",false);

  *out << "\nROL options:" << std::endl;
  rolParams.sublist("ROL Options").print(*out);
  *out << std::endl;

  // Reduced-space algorithm: line-search by default, trust-region otherwise.
  const ROL::Ptr<ROL::StatusTest<double>> status = ROL::makePtr<ROL::StatusTest<double>>(rolParams.sublist("ROL Options"));
  ROL::Ptr<ROL::Step<double>> step;
  if(rolParams.get<std::string>("Step Method", "Line Search") == "Line Search")
    step = ROL::makePtr<ROL::LineSearchStep<double>>(rolParams.sublist("ROL Options"));
  else
    step = ROL::makePtr<ROL::TrustRegionStep<double>>(rolParams.sublist("ROL Options"));
  algo_ = ROL::makePtr<ROL::Algorithm<double>>(step, status, true);

  // Validate the "Matrix Based Dot Product" options. NOTE(review): the
  // selected matrix is not applied anywhere yet; only the values are checked.
  if(rolParams.isSublist("Matrix Based Dot Product")) {
    const Teuchos::ParameterList& matrixDotProductList = rolParams.sublist("Matrix Based Dot Product");
    const auto matrixType = matrixDotProductList.get<std::string>("Matrix Type");
    if(matrixType == "Hessian Of Response") {
      // The const sublist() accessor throws if these options are missing.
      matrixDotProductList.sublist("Matrix Types").sublist("Hessian Of Response");
    } else if(matrixType != "Identity") {
      TEUCHOS_TEST_FOR_EXCEPTION(true, Teuchos::Exceptions::InvalidParameter,
          std::endl << "Error in Piro::PerformROLAnalysis: " <<
          "Matrix Type not recognized. Available options are: \n" <<
          "\"Identity\" and \"Hessian Of Response\""<<std::endl);
    }
  }

  boundConstrained = rolParams.get<bool>("Bound Constrained", false);
}


// Public functions overridden from ModelEvaluator

/// Space of parameter vector l; forwarded unchanged from the wrapped model.
template<class Scalar>
Teuchos::RCP<const Thyra::VectorSpaceBase<Scalar> >
ROLSolver<Scalar>::get_p_space(int l) const
{
  const Teuchos::RCP<const Thyra::VectorSpaceBase<Scalar> > space = model_->get_p_space(l);
  return space;
}


/// Component names of parameter vector j, as reported by the wrapped model.
template<class Scalar>
Teuchos::RCP<const Teuchos::Array<std::string> >
ROLSolver<Scalar>::get_p_names(int j) const
{
  const Teuchos::RCP<const Teuchos::Array<std::string> > names = model_->get_p_names(j);
  return names;
}

/**
 * Space of response j.
 *
 * Indices below num_g_ are forwarded to the wrapped model. Index num_g_ is the
 * extra "solution" response (see createOutArgsImpl); it is reported in
 * x_space_, the space in which the constructor creates the solution vector x_,
 * so this stays consistent with x_ rather than querying the Piro solver's own
 * x space.
 */
template<class Scalar>
Teuchos::RCP<const Thyra::VectorSpaceBase<Scalar> >
ROLSolver<Scalar>::get_g_space(int j) const
{
  if (j == num_g_)
    return x_space_; // last response vector is the solution (same space as x_)
  return model_->get_g_space(j);
}

/// Component names of response l, as reported by the wrapped model.
template<class Scalar>
Teuchos::ArrayView<const std::string>
ROLSolver<Scalar>::get_g_names(int l) const
{
  const Teuchos::ArrayView<const std::string> names = model_->get_g_names(l);
  return names;
}


/// Input arguments of this evaluator: only the number of parameter vectors
/// (num_p_) is declared, together with this evaluator's description.
template<class Scalar>
Thyra::ModelEvaluatorBase::InArgs<Scalar>
ROLSolver<Scalar>::createInArgs() const
{
  MEB::InArgsSetup<Scalar> setup;
  setup.setModelEvalDescription(this->description());
  setup.set_Np(num_p_);
  return setup;
}


// Private functions overridden from ModelEvaluatorDefaultBase

/// Output arguments of this evaluator.
///
/// Declares num_p_ parameters and num_g_ + 1 responses: the extra, last
/// response slot is reserved for the solution vector. DgDp support for the
/// first num_g_ responses mirrors the wrapped model; the solution response
/// keeps the OutArgsSetup default.
template<class Scalar>
Thyra::ModelEvaluatorBase::OutArgs<Scalar>
ROLSolver<Scalar>::createOutArgsImpl() const
{
  MEB::OutArgsSetup<Scalar> result;
  result.setModelEvalDescription(this->description());
  result.set_Np_Ng(num_p_, num_g_ + 1);

  // Query the wrapped model once; only its supports() flags are needed.
  const MEB::OutArgs<Scalar> wrappedOutArgs = model_->createOutArgs();
  for (int g = 0; g < num_g_; ++g) {
    for (int p = 0; p < num_p_; ++p) {
      result.setSupports(MEB::OUT_ARG_DgDp, g, p,
                         wrappedOutArgs.supports(MEB::OUT_ARG_DgDp, g, p));
    }
  }

  return result;
}


/**
 * Runs the ROL optimization configured in the constructor.
 *
 * The parameters are optimized in place in result_ (the vector the
 * constructor also wrapped in rol_p_ptr_). In the full-space ("Full Space")
 * formulation, the state x_ is optimized as well. With "Bound Constrained",
 * the wrapped model's lower/upper parameter bounds are enforced. Any text
 * returned by the reduced-space algorithm is echoed to *out.
 *
 * NOTE(review): inArgs is not read, outArgs is not populated and the ROL
 * status flag is not reported; the optimum is presumably consumed through
 * result_ by the caller — confirm.
 */
template<class Scalar>
void ROLSolver<Scalar>::evalModelImpl(
  const Thyra::ModelEvaluatorBase::InArgs<Scalar> &inArgs,
  const Thyra::ModelEvaluatorBase::OutArgs<Scalar> &outArgs
  ) const
{
  // Optional bound constraint on the product parameter vector. ROL::Ptr and
  // ROL::makePtr are used so this works with whichever pointer type ROL is
  // configured with (BoundConstraint_SimOpt below expects a ROL::Ptr).
  ROL::Ptr<ROL::BoundConstraint<double> > boundConstraint;
  if(boundConstrained) {
    Teuchos::Array<Teuchos::RCP<const Thyra::VectorBase<double>>> p_lo_vecs(num_p_);
    Teuchos::Array<Teuchos::RCP<const Thyra::VectorBase<double>>> p_up_vecs(num_p_);
    const auto lowerBounds = model_->getLowerBounds();
    const auto upperBounds = model_->getUpperBounds();
    for (int i = 0; i < num_p_; ++i) {
      p_lo_vecs[i] = lowerBounds.get_p(p_indices_[i]);
      p_up_vecs[i] = upperBounds.get_p(p_indices_[i]);
      TEUCHOS_TEST_FOR_EXCEPTION((p_lo_vecs[i] == Teuchos::null) || (p_up_vecs[i] == Teuchos::null), Teuchos::Exceptions::InvalidParameter,
          std::endl << "Error in Piro::PerformROLAnalysis: " <<
          "Lower and/or Upper bounds pointers are null, cannot perform bound constrained optimization"<<std::endl);
    }
    const Teuchos::RCP<Thyra::VectorBase<double>> p_lo = Thyra::defaultProductVector<double>(p_space_, p_lo_vecs());
    const Teuchos::RCP<Thyra::VectorBase<double>> p_up = Thyra::defaultProductVector<double>(p_space_, p_up_vecs());
    boundConstraint = ROL::makePtr<ROL::Bounds<double> >(
        ROL::makePtr<ROL::ThyraVector<double> >(p_lo), ROL::makePtr<ROL::ThyraVector<double> >(p_up));
  }

  // Scaled wrappers around x_ and result_. The scaling vectors are all ones,
  // so the dot products are currently unchanged (a hook for future scaling).
  const Teuchos::RCP<Thyra::VectorBase<double> > scaling_vector_x = x_->clone_v();
  ::Thyra::put_scalar<double>( 1.0, scaling_vector_x.ptr());
  ROL::PrimalScaledThyraVector<double> rol_x_primal(x_, scaling_vector_x);

  const Teuchos::RCP<Thyra::VectorBase<double> > scaling_vector_p = result_->clone_v();
  ::Thyra::put_scalar<double>( 1.0, scaling_vector_p.ptr());
  ROL::PrimalScaledThyraVector<double> rol_p_primal(result_, scaling_vector_p);

  std::vector<std::string> output;
  if ( useFullSpace ) {
    // Full-space (SimOpt) formulation over (state, parameters) subject to
    // constr_ptr_. First solve the constraint for the state at the current
    // parameters so the optimization starts from a consistent state.
    ROL::Vector_SimOpt<double> sopt_vec(ROL::makePtrFromRef(rol_x_primal),ROL::makePtrFromRef(rol_p_primal));
    auto r_ptr = rol_x_ptr_->clone();
    double tol = 1e-5; // non-const: Constraint_SimOpt::solve takes the tolerance by reference
    constr_ptr_->solve(*r_ptr,*rol_x_ptr_,*rol_p_ptr_,tol);
    Teuchos::ParameterList pl = rolParams.sublist("ROL Options");
    if(boundConstrained) {
      // The state is unconstrained; only the parameters carry bounds.
      ROL::BoundConstraint<double> u_bnd(*rol_x_ptr_);
      ROL::Ptr<ROL::BoundConstraint<double> > bnd = ROL::makePtr<ROL::BoundConstraint_SimOpt<double> >(ROL::makePtrFromRef(u_bnd),boundConstraint);
      ROL::OptimizationProblem<double> prob(obj_ptr_, ROL::makePtrFromRef(sopt_vec), bnd, constr_ptr_, r_ptr);
      ROL::OptimizationSolver<double> optSolver(prob, pl);
      optSolver.solve(*out);
    } else {
      ROL::OptimizationProblem<double> prob(obj_ptr_, ROL::makePtrFromRef(sopt_vec), constr_ptr_, r_ptr);
      ROL::OptimizationSolver<double> optSolver(prob, pl);
      optSolver.solve(*out);
    }
  } else {
    // Reduced-space formulation: optimize the parameters only, using the
    // reduced objective built in the constructor.
    if(boundConstrained)
      output = algo_->run(rol_p_primal, *reduced_obj_ptr_, *boundConstraint, print_, *out);
    else
      output = algo_->run(rol_p_primal, *reduced_obj_ptr_, print_, *out);
  }

  for (const auto& line : output) {
    *out << line;
  }
}

} // namespace SiYuan

#endif
